import os
import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class RAFDBDataset(Dataset):
    """RAF-DB dataset adapter for PyTorch.

    Reads the RAF-DB partition/label index file and exposes the aligned
    face crops of one split ('train' or 'test') with zero-based emotion
    labels.
    """

    def __init__(self, phase, base_directory, transform=None):
        """
        :param phase: 'train' selects training rows; any other value selects
                      the test rows (matches the original behavior)
        :param base_directory: path to the RAF-DB root directory (must contain
                               EmoLabel/ and Image/aligned/)
        :param transform: optional callable applied to each PIL image
        """
        # NOTE: 'patition' is the actual (misspelled) filename shipped
        # with the RAF-DB distribution — do not "fix" it.
        label_path = os.path.join(base_directory, 'EmoLabel/list_patition_label.txt')
        index = pd.read_csv(label_path, sep=' ', header=None)

        # Column 0 holds names like 'train_00001.jpg' / 'test_0001.jpg'.
        prefix = 'train' if phase == 'train' else 'test'
        dataset = index.loc[index[0].str.startswith(prefix)]

        # Map 'train_00001.jpg' -> <base>/Image/aligned/train_00001_aligned.jpg
        self.img_paths = [
            os.path.join(base_directory, 'Image/aligned',
                         name.split('.')[0] + '_aligned.jpg')
            for name in dataset[0].to_list()
        ]

        # Remapping notes carried over from the original source (presumably
        # a cross-dataset label correspondence — TODO confirm with the author):
        # 1: crying 3: laughing 2: shouting 0: others
        # 1->14 3->3 2->05 0->26
        # 0->2 1->1 2->0 3->3 4->1 5->2 6->0
        #
        # RAF-DB labels are 1-based; shift to 0-based:
        # 0: Surprise 1: Fear 2: Disgust 3: Happiness 4: Sadness 5: Anger 6: Neutral
        self.labels = [label - 1 for label in dataset[1].to_list()]

        self.transform = transform

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample *idx*.

        :param idx: integer index into the split
        :returns: PIL image (or transformed image) and its 0-based label
        :raises FileNotFoundError: if the image file is missing or unreadable
        """
        label = self.labels[idx]

        path = self.img_paths[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread silently returns None on a missing/corrupt file,
            # which would otherwise surface as a cryptic cvtColor error.
            raise FileNotFoundError(f'Could not read image: {path}')

        # OpenCV decodes to BGR; convert to RGB before handing to PIL.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        return image, label
